{%block imports %}
import unittest
from os.path import dirname, abspath, join
from manticore.platforms import wasm
from manticore.wasm import ManticoreWASM
from manticore.wasm.types import I32, F32, I64, F64, Trap, FunctionType
from manticore.wasm.structure import TableInst, MemInst, GlobalInst, HostFunc
from manticore.core.plugin import Plugin
from manticore.core.state import TerminateState
from base64 import b64decode
import math


from manticore.utils import config
# Switch the "core" config group's mprocessing option to its `single` setting.
# NOTE(review): presumably this keeps Manticore in one process for the tests — confirm.
consts = config.get_group("core")
consts.mprocessing = consts.mprocessing.single

def assertEqualNan(testcase, a_list, b):
    """Assert that the returned values match the expected ones, treating NaN as equal to NaN.

    :param testcase: test case object whose ``assertTrue`` / ``assertAlmostEqual`` are used
        for the value comparisons
    :param a_list: 3-nested lists: one for each state, containing one for each of the n
        possible return values, each containing the output of solver.get_all_values
    :param b: list of expected return values
    :raises AssertionError: if there is not exactly one state, if the number of returned
        values differs from ``len(b)``, if a return value has more than one possible
        solution, or if a value does not match its expected counterpart
    """
    assert len(a_list) == 1, f"Diverged into {len(a_list)} states!"
    for state_a in a_list:
        # The message must use state_a: the inner loop variable is not bound yet at this
        # point, so referencing it would raise instead of reporting the mismatch.
        assert len(state_a) == len(b), f"Returned {len(state_a)} values! (Expected {len(b)})"
        for a in state_a:
            assert len(a) <= 1, "Found multiple possible values!"
            # NOTE(review): since len(a) <= 1, zip(a, b) only ever pairs a[0] with b[0], so
            # every return slot is compared against the first expected value. That is fine
            # for single-return functions; confirm before relying on multi-value returns.
            for a_item, b_item in zip(a, b):
                # Integer results are reinterpreted as floats when a float is expected.
                if isinstance(b_item, F32) and isinstance(a_item, int):
                    a_item = F32.cast(a_item)
                if isinstance(b_item, F64) and isinstance(a_item, int):
                    a_item = F64.cast(a_item)
                # NaN never compares equal to itself, so it needs a dedicated check.
                if math.isnan(b_item):
                    testcase.assertTrue(math.isnan(a_item))
                else:
                    testcase.assertAlmostEqual(a_item, b_item)
{% endblock %}

def noarg(_ctx):
    """Host-function stub taking only the context; ignores it and returns no values."""
    return []
def onearg(_ctx, _c):
    """Host-function stub taking the context and one argument; returns no values."""
    return []
def twoarg(_ctx, _x, _y):
    """Host-function stub taking the context and two arguments; returns no values."""
    return []

# Environment handed to ManticoreWASM under the module name "spectest".
spectest = {
    # Host functions: each accepts the listed parameter types and returns nothing.
    "print": HostFunc(FunctionType([], []), noarg),
    "print_i32": HostFunc(FunctionType([I32], []), onearg),
    "print_i32_f32": HostFunc(FunctionType([I32, F32], []), twoarg),
    "print_f64_f64": HostFunc(FunctionType([F64, F64], []), twoarg),
    "print_f32": HostFunc(FunctionType([F32], []), onearg),
    "print_f64": HostFunc(FunctionType([F64], []), onearg),
    # Globals, all holding 666.
    # NOTE(review): the second argument is presumably the mutability flag — confirm.
    "global_i32": GlobalInst(I32(666), True),
    "global_f32": GlobalInst(F32(666), True),
    "global_f64": GlobalInst(F64(666), True),
    # Table with 10 empty entries; memory with 64 * 1024 zeroed entries.
    # NOTE(review): the trailing 20 and 2 are presumably the maximum sizes — confirm.
    "table": TableInst([None] * 10, 20),
    "memory": MemInst([0x0] * (64 * 1024), 2),
}

class RaiseExceptionPlugin(Plugin):
    """Plugin that flags its Manticore object when a state terminates because of a Trap."""

    def did_terminate_state_callback(self, current_state, exc):
        """Set ``self.manticore.trapped`` to True when ``exc`` mentions a raised Trap.

        :param current_state: the state that terminated (unused)
        :param exc: the termination reason; only its string form is inspected
        """
        if "raised Trap" not in str(exc):
            return
        self.manticore.trapped = True

{% for module in modules %}
class WASMTest_{{ module.name }}(unittest.TestCase):
    """Runs every generated test for one WASM module as a subtest with symbolic arguments."""

    _multiprocess_can_split_ = False
    filename = join(dirname(abspath(__file__)), "{{ module.filename }}")
    subtest_count = 0

    def run(self, result=None):
        """Run the test method, then fold the number of executed subtests into ``testsRun``."""
        result = super().run(result)
        # The one test method is replaced in the count by its subtests, hence the -1.
        setattr(result, "testsRun", self.subtest_count + getattr(result, "testsRun", 0) - 1)
        return result

    def _new_manticore(self):
        """Return a fresh ManticoreWASM for this module with the Trap-detecting plugin registered."""
        m = ManticoreWASM(self.filename, sup_env={"spectest":spectest}, exec_start=True)
        m.register_plugin(RaiseExceptionPlugin())
        return m

    def test_{{ module.name }}(self):
        m = self._new_manticore()

        {# One subtest is rendered per entry of module.tests; all of them share the object m. #}
        {% for test in module.tests %}
        name = {{ test.func | escape_null }} + "_L{{test.line}}"
        with self.subTest(msg=name):
            print("======V======", name, "(Symbolic)", "======V======")
            self.subtest_count += 1

            if self.subtest_count % 50 == 0:
                if {{module.allow_reinit}}:
                    print("Reinitializing Manticore object after 50 symbolic tests")
                    m = self._new_manticore()

            def create_argv(state):
                argv = []
                {# Each argument is rendered as an expression that is then pinned to its concrete value. #}
                {% for arg in test.args %}
                argv.append({{arg.constraint}})
                state.constrain(argv[-1] == {{arg.val}})
                {% endfor %}
                return argv

            expected = [{{ test.rets }}]
            {% if test.type == "assert_return" %}
            with m.kill_timeout(150):
                m.invoke(name={{ test.func | escape_null }}, argv_generator=create_argv)
                setattr(m, 'trapped', False)
                m.run()
                self.assertFalse(getattr(m, 'trapped', False))
                if not getattr(m, '_killed', False):
                    rets = m.collect_returns(len(expected))
                    m._reinit()
                    assertEqualNan(self, rets, expected)
                else:
                    m._killed.value = False
                    print(name, "Timed Out, Reinitializing Manticore")
                    m = self._new_manticore()
            {% endif %}
            {% if test.type == "assert_trap" %}
            with m.kill_timeout(150):
                m.invoke(name={{ test.func | escape_null }}, argv_generator=create_argv)
                # Clear any flag left by an earlier subtest on a reused object, so the
                # assertion below only passes if this run actually trapped.
                setattr(m, 'trapped', False)
                m.run()
                self.assertTrue(getattr(m, 'trapped', False))
                if {{module.allow_reinit}}:
                    print("Reinitializing after Trap")
                    m = self._new_manticore()
                else:
                    m._reinit()
            {% endif %}
            {% if test.type == "action" %}
            with m.kill_timeout(150):
                m.invoke(name={{ test.func | escape_null }}, argv_generator=create_argv)
                m.run()
                m._reinit()
            {% endif %}
        {% endfor %}

{% endfor %}

# Run all generated test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
